Remove _supports_static_cache = True for some model classes #34975

Merged: 11 commits merged into main from fix_compile_3 on Jan 28, 2025

Conversation

ydshieh (Collaborator) commented on Nov 27, 2024

What does this PR do?

Remove _supports_static_cache = True for some model classes. See the comments in the changes.

These flags were previously set to True simply because the static cache could be used without torch.compile. After #34247, however, the static cache is effectively tied to torch.compile, so we should only claim support when the model actually works under torch.compile.
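
As a rough illustration only, the change amounts to dropping the class attribute from the affected model classes. The class name below is hypothetical; `_supports_static_cache` and `_supports_cache_class` are the actual `PreTrainedModel` attributes.

```python
# Illustrative sketch, not code from this PR -- "SomeMoePreTrainedModel" is a made-up name.
from transformers import PreTrainedModel


class SomeMoePreTrainedModel(PreTrainedModel):
    _supports_cache_class = True
    # _supports_static_cache = True  # removed: the forward pass contains
    #                                # data-dependent ops (e.g. Tensor.tolist())
    #                                # that torch.compile cannot trace
```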

@@ -330,6 +330,8 @@ def forward(self, hidden_states):
 ) # [num_tokens, num_experts]
 gates = zeros.scatter(1, top_k_indices, 1) # [num_tokens, num_experts]
 expert_size = gates.long().sum(0) # [num_experts,]
+# (This causes torch.compile to fail with `torch._dynamo.exc.Unsupported: Backend compiler failed with a fake tensor exception at`)
+# (and `DataDependentOutputException`)
 expert_size = expert_size.tolist()
ydshieh (Collaborator, Author):

jimba has this line expert_size = expert_size.tolist() too, and it has no _supports_static_cache = True. Let's do the same for this model.
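
As an aside, here is a hedged, minimal repro (not taken from the PR) of why a .tolist() call in the routing code is incompatible with torch.compile:

```python
# Illustrative repro: Tensor.tolist() has a data-dependent output, so strict
# tracing fails, matching the DataDependentOutputException quoted above.
import torch


def route(gates: torch.Tensor):
    expert_size = gates.long().sum(0)  # [num_experts]
    return expert_size.tolist()        # not traceable with fullgraph=True


compiled = torch.compile(route, fullgraph=True)
try:
    compiled(torch.ones(8, 4))
except Exception as err:  # expect a torch._dynamo "Unsupported"-style error
    print(type(err).__name__, err)
```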

@@ -1155,7 +1156,7 @@ def forward(
elif position_ids is None:
position_ids = cache_position.unsqueeze(0)

if (pixel_values, image_encoder_embeddings, perceiver_embeddings).count(None) != 2:
ydshieh (Collaborator, Author):

This line makes torch.compile fail with another, different kind of error.

ydshieh changed the title from "Set some" to "Remove _supports_static_cache = True for some model classes" on Nov 27, 2024
HuggingFaceDocBuilderDev commented

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

ydshieh (Collaborator, Author) commented on Jan 9, 2025

kindly ping @ArthurZucker

ArthurZucker (Collaborator) left a review

Very nice PR, thanks for trying to fix these!

@@ -868,6 +868,8 @@ def forward(
 )
 hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training)
 # Fill in zeros for cross_attention hidden_states of tokens attending to no images
+# (This causes torch.compile to fail with `torch._dynamo.exc.Unsupported: dynamic shape operator: aten.nonzero.default`)
+# (setting torch._dynamo.config.capture_dynamic_output_shape_ops = True may help, but this is untested)
 hidden_states[cross_attention_gate == 0] = hidden_states[cross_attention_gate == 0].fill_(0)
ArthurZucker (Collaborator):

Mmm I am wondering if using torch.masked_fill would be better here and would avoid the graph break?
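
A hedged sketch (simplified shapes, not the actual idefics tensors) of the two variants: they are numerically equivalent, but only the masked_fill version keeps static shapes, since boolean-mask assignment lowers to aten.nonzero with a dynamic output shape.

```python
import torch


def zero_gated_indexing(hidden_states, cross_attention_gate):
    # boolean-mask assignment -> aten.nonzero -> dynamic output shape
    hidden_states[cross_attention_gate == 0] = 0.0
    return hidden_states


def zero_gated_masked_fill(hidden_states, cross_attention_gate):
    # masked_fill broadcasts the mask and keeps shapes static
    return hidden_states.masked_fill((cross_attention_gate == 0)[:, :, None], 0.0)


h = torch.randn(2, 5, 8)
gate = torch.randint(0, 2, (2, 5)).float()
assert torch.allclose(zero_gated_indexing(h.clone(), gate), zero_gated_masked_fill(h, gate))
```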

ydshieh (Collaborator, Author):

Using

hidden_states.masked_fill((cross_attention_gate == 0)[:, :, None], 0.0)

seems to avoid that failure here, but then I got another failure

E               torch._dynamo.exc.UserError: Dynamic control flow is not supported at the moment. Please use functorch.experimental.control_flow.cond to explicitly capture the control flow. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#cond-operands
E
E               from user code:
E                  File "/transformers/src/transformers/models/idefics/modeling_idefics.py", line 1609, in forward
E                   outputs = self.model(
E                 File "/transformers/src/transformers/models/idefics/modeling_idefics.py", line 1315, in forward
E                   layer_outputs = vblock(
E                 File "/transformers/src/transformers/models/idefics/modeling_idefics.py", line 1277, in vblock
E                   layer_outputs = main_block(
E                 File "/transformers/src/transformers/models/idefics/modeling_idefics.py", line 719, in forward
E                   hidden_states, self_attn_weights, present_key_value = self.self_attn(
E                 File "/transformers/src/transformers/models/idefics/modeling_idefics.py", line 611, in forward
E                   cos, sin = self.rotary_emb(value_states, seq_len=max(kv_seq_len, q_len))
E                 File "/transformers/src/transformers/models/idefics/modeling_idefics.py", line 431, in forward
E                   if seq_len > self.max_seq_len_cached:
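
For reference, a hedged toy repro (not the idefics rotary code) of this class of failure, where a Python if branches on a traced, data-dependent value:

```python
# Illustrative only: branching on a tensor-derived condition under fullgraph
# tracing triggers the "Dynamic control flow is not supported" family of
# errors, whose message points to torch.cond as the workaround.
import torch


def f(x):
    if x.sum() > 0:  # data-dependent Python branch
        return x + 1
    return x - 1


try:
    torch.compile(f, fullgraph=True)(torch.randn(4))
except Exception as err:
    print(type(err).__name__, err)
```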

ydshieh (Collaborator, Author):

I can find if seq_len > self.max_seq_len_cached: in many modeling files, like llama, but there it is only used inside

        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

which is not present in idefics.

Anyway, I will update the line to use masked_fill.

ydshieh (Collaborator, Author):

@ArthurZucker Updated with your suggestion, thanks a lot! As mentioned above, though, it still fails to compile due to other errors 😢

ArthurZucker removed their request for review on January 16, 2025 15:51
ydshieh requested a review from ArthurZucker on January 23, 2025 17:04
ArthurZucker (Collaborator) left a review

Thanks for iterating!

Comment on lines +333 to +334
# (This causes torch.compile to fail with `torch._dynamo.exc.Unsupported: Backend compiler failed with a fake tensor exception at`)
# (and `DataDependentOutputException`)
ArthurZucker (Collaborator):

Yes, tolist() is obviously wrong here!

@@ -868,7 +868,7 @@ def forward(
 )
 hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training)
 # Fill in zeros for cross_attention hidden_states of tokens attending to no images
-hidden_states[cross_attention_gate == 0] = hidden_states[cross_attention_gate == 0].fill_(0)
+hidden_states = hidden_states.masked_fill((cross_attention_gate == 0)[:, :, None], 0.0)
ArthurZucker (Collaborator):

nice

ydshieh (Collaborator, Author) commented on Jan 27, 2025

run-slow: idefics

This comment contains run-slow, running the specified jobs: ['models/idefics'] ...

ydshieh (Collaborator, Author) commented on Jan 28, 2025

The failing tests are unrelated to this PR and are already failing on main. Merging now.

ydshieh merged commit bf16a18 into main on Jan 28, 2025
17 of 18 checks passed
ydshieh deleted the fix_compile_3 branch on January 28, 2025 09:42
bursteratom pushed a commit to bursteratom/transformers that referenced this pull request Jan 31, 2025
…gface#34975)

* use mask_fill

* remove comment

---------

Co-authored-by: ydshieh <[email protected]>
elvircrn pushed a commit to elvircrn/transformers that referenced this pull request on Feb 13, 2025 (same commit message as above)
sbucaille pushed a commit to sbucaille/transformers that referenced this pull request on Feb 16, 2025 (same commit message as above)